function savings_ev_quadrant_weighted = generate_results_quadrant_external_validity(bsimax,sample,sample_wave,target_wave,tcost,covariate,winter_prevuse,SMC,savedir,fbase)
% GENERATE_RESULTS_QUADRANT_EXTERNAL_VALIDITY  Apply an EWM quadrant rule across waves.
%
% Calculates cost savings or energy reduction from applying EWM quadrant
% rules across waves: the IPW scores of the sample wave are re-weighted by
% the density ratio of the covariate distribution (target wave / sample
% wave), a quadrant rule is chosen by grid search on the re-weighted scores,
% and the rule is then evaluated on the target wave.
%
% Inputs:
% (1) bsimax: number of bootstrap runs. NOTE: no bootstrap is run in this
%     function, so this argument is currently unused (kept for the callers).
% (2) sample: cross-sectional dataset merged using elec_wave367_cross.csv
%     and propensity scores and renamed as data_estimation_pooled_ps.
%     Must contain the column opower_paper_wave.
% (3) sample_wave: the wave the original EWM rule was estimated on
%     (=3, 6, or 7)
% (4) target_wave: the wave the EWM rule is applied to (=3, 6, or 7).
%     May equal sample_wave, in which case the density ratio is 1 on every
%     populated grid cell.
% (5) tcost: private marginal cost of implementing the program per household
%     >0 represents cost savings; =0 represents kWh reduction
% (6) covariate: covariate for the analysis. Bin counts are defined here for
%     'income', 'size' and 'vintage' only; any other value raises an error.
% (7) winter_prevuse: 1 or 0; indicates whether baseline consumption is
%     calculated as the mean of consumption in winter months (Jan and Feb)
%     or as the mean of specified pre-treatment periods. Only used here to
%     build the output file name.
% (8) SMC: =1 use social marginal cost; =0 use retail electricity price
% (9) savedir: directory to save the rule parameters (.mat file)
% (10) fbase: string used to indicate the propensity score and baseline
%     months specification; used as the prefix of the output file name
%
% Output:
% (1) savings_ev_quadrant_weighted: a 1-row table with the variables
%     covariate, percent (share of households to treat), savings (point
%     estimate per household), totalsavings, cilb and ciub. cilb/ciub are
%     NaN because no bootstrap distribution is computed in this function.
%
% Side effects: writes '<fbase>_coef_ev_quadrant_..._weighted.mat' into
% savedir and prints progress messages.

%% Validate inputs early (before the expensive grid search)
valid_waves = [3 6 7];
if ~isscalar(sample_wave) || ~ismember(sample_wave,valid_waves) || ...
        ~isscalar(target_wave) || ~ismember(target_wave,valid_waves)
    error('generate_results_quadrant_external_validity:invalidWave', ...
        'sample_wave and target_wave must each be 3, 6, or 7.');
end

switch covariate
    case 'income'
        N_bins_x=29; % step size of 5000
        N_bins_y=272; % bins of prev usage (272 in the baseline specification)
    case 'size'
        N_bins_x=48; % step size of 100
        N_bins_y=272; % bins of prev usage (272 in the baseline specification)
    case 'vintage'
        N_bins_x=17; % step size of 10
        N_bins_y=272; % bins of prev usage (272 in the baseline specification)
    otherwise
        error('generate_results_quadrant_external_validity:invalidCovariate', ...
            'No bin specification for this covariate; expected ''income'', ''size'' or ''vintage''.');
end

switch winter_prevuse
    case 1
        s_winter = '_winter';
    case 0
        s_winter = '';
    otherwise
        error('generate_results_quadrant_external_validity:invalidWinterPrevuse', ...
            'winter_prevuse must be 0 or 1.');
end

if (tcost)>0 % cost_savings
    if (SMC)
        s_cost = 'SMC';
    else
        s_cost = 'PMC';
    end
elseif (tcost)==0
    s_cost='kwh';
else
    error('generate_results_quadrant_external_validity:invalidTcost', ...
        'tcost must be >0 (cost savings) or =0 (kWh reduction).');
end

%% Build the common grid from the pooled sample
% wave=0 is what the original code passes for the pooled sample.
[X,~,~,~,~]=generate_inputs_quadrant_rule(sample,0,tcost,covariate,SMC);

% boo1: bin edges for the covariate (N_bins_x+1 values)
% boo2: bin edges for prev_use      (N_bins_y+1 values)
% The edges depend only on the range of the data, so X need not be sorted.
[~,boo1,boo2]=histcounts2(X(:,1),X(:,2),...
    [N_bins_x,N_bins_y],'Normalization','count');

% number of grid points = number of covariate edges * number of prev_use edges
nu = length(boo1)*length(boo2);

%% Wave-specific inputs on the pooled grid
% All waves are binned with the pooled edges so that every wave shares the
% same grid (avoids X values that are present in one wave but not another).
S = local_wave_inputs(sample,sample_wave,tcost,covariate,SMC,nu,boo1,boo2);
if target_wave==sample_wave
    T = S;
else
    T = local_wave_inputs(sample,target_wave,tcost,covariate,SMC,nu,boo1,boo2);
end

% Target-wave quantities used by the grid search and the output table
x1_target_wave=T.x1;
x2_target_wave=T.x2;
grid_x1_target_wave=boo1;
grid_x2_target_wave=boo2;
k1_target_wave=N_bins_x+1;
k2_target_wave=N_bins_y+1;
Xu_target_wave=T.Xu;
nw_target_wave=T.nw;
n_target_wave=T.n;
gu_target_wave=T.gu;
scaled_att_target_wave=T.att;

% Sample-wave quantities (saved alongside the rule)
Xu_sample_wave=S.Xu;
nw_sample_wave=S.nw;
n_sample_wave=S.n;
gu_sample_wave=S.gu;
scaled_att_sample_wave=S.att;

%% Density ratio and re-weighted scores
% Only X is assumed known for the target population, so the sample wave's gu
% is re-weighted by the ratio of the marginal densities of X
% (target / sample) on each grid cell. Cells that are empty in the sample
% wave give Inf or NaN and are set to 0.
dr=(nw_target_wave./n_target_wave)./(nw_sample_wave./n_sample_wave);
dr(isinf(dr))=0;
dr(isnan(dr))=0;
gu_w=gu_sample_wave.*dr;

%% Density ratio for each household (not each unique grid cell)
dr_hh = local_density_ratio_per_household(dr,S.bin1,S.bin2,length(boo2));
% D, Y, ps and dr_hh are all in the sorted-X order of the sample wave, and
% the normalisation uses the sample wave's own n.
scaled_att_sample_wave_weighted = local_scaled_att(S.D,S.Y,S.ps,S.n,dr_hh);

%% Minimize cost
tic1= tic;

m_W = Inf*ones(k1_target_wave,k2_target_wave);
v_x1 = grid_x1_target_wave;
v_x2 = grid_x2_target_wave;

minW = Inf;

for i1=1:k1_target_wave
    boo_i1=(x1_target_wave - grid_x1_target_wave(i1));
    for i2=1:k2_target_wave
        boo_i2=(x2_target_wave - grid_x2_target_wave(i2));
        for sign1 = -1:2:1 % sign1 in [-1,1]: treat below or above the threshold
            for sign2 = -1:2:1
                % must meet both threshold conditions to be selected as treated
                select = ((sign1.*boo_i1)>=0).*((sign2.*boo_i2)>=0);
                W = sum(gu_w.*select);
                % best value over the four quadrants at this threshold pair
                if W<m_W(i1,i2)
                   m_W(i1,i2) = W;
                end
                % best value overall
                if (W < minW)
                    min_i1 = i1;
                    min_i2 = i2;
                    min_sign1 = sign1;
                    min_sign2 = sign2;
                    minW = W;
                end
            end
        end
    end
end

fprintf('Grid Serach takes %.2f sec\n',toc(tic1))

if isinf(minW)
    % Happens only if every candidate W was NaN/Inf; min_i1 etc. would be
    % undefined below.
    error('generate_results_quadrant_external_validity:noFiniteObjective', ...
        'Grid search found no finite objective value; check gu_w for NaN/Inf.');
end

%% treatment status for each grid cell of the target wave

in_Ghat = (min_sign1*(x1_target_wave - grid_x1_target_wave(min_i1))>=0).*(min_sign2*(x2_target_wave - grid_x2_target_wave(min_i2))>=0);
% Placeholders: no bootstrap is run here. They are kept because they are
% written to the .mat file below.
vhats_n=0;
vhats_p=0;

%% Export rule parameters for graphs

s_wave=sprintf('%d%d',sample_wave,target_wave);

filename_coefs=sprintf('coef_ev_quadrant_%s_%s%s_%s_weighted.mat',covariate,s_cost,s_winter,s_wave);
filename = sprintf('%s_%s',fbase,filename_coefs);
fpathf = fullfile(savedir,filename);
fprintf('Saving "%s" to "%s" ... ',filename,savedir);
save(fpathf,'min_i1','min_i2','min_sign1','min_sign2','m_W',...
   'in_Ghat','v_x1','v_x2','Xu_target_wave','nw_target_wave','gu_w','gu_target_wave','n_target_wave','covariate','vhats_n','vhats_p','minW',...
   'dr','boo1','boo2','Xu_sample_wave','nw_sample_wave','n_sample_wave','gu_sample_wave','scaled_att_target_wave','scaled_att_sample_wave','scaled_att_sample_wave_weighted');
fprintf('   Done!\n');

%% Output for table
percent=sum(nw_target_wave.*in_Ghat)/n_target_wave;
savings=sum(gu_target_wave.*in_Ghat)/n_target_wave;
if (tcost)
    total_savings=savings*n_target_wave*12;
else
    total_savings=savings*n_target_wave*12/1000;
end

% The original code took prctile(vhats_abs,95) here, but vhats_abs is never
% defined and no bootstrap draws exist in this function, so the call
% errored. Report the bounds as missing rather than a zero-width interval.
ci_lb=NaN;
ci_ub=NaN;

savings_ev_quadrant_weighted=nan(1,6);
savings_ev_quadrant_weighted=array2table(savings_ev_quadrant_weighted,'VariableNames',{'covariate','percent','savings','totalsavings','cilb','ciub'});

savings_ev_quadrant_weighted.covariate = {covariate};
savings_ev_quadrant_weighted.percent=percent;
savings_ev_quadrant_weighted.savings=savings;
savings_ev_quadrant_weighted.totalsavings=total_savings;
savings_ev_quadrant_weighted.cilb=ci_lb;
savings_ev_quadrant_weighted.ciub=ci_ub;

end

function W = local_wave_inputs(sample,wave,tcost,covariate,SMC,nu,boo1,boo2)
% Gather everything needed for one wave, binned on the pooled grid edges.
% Fields D, Y, ps, bin1, bin2 are in sortrows(X) order for that wave.

sample_cross_wave=sample(sample.opower_paper_wave==wave,:);
[Xw,Yw,Dw,W.n,psw]=generate_inputs_quadrant_rule(sample_cross_wave,wave,tcost,covariate,SMC);

% Unweighted scaled ATT (computed in the original row order)
W.att = local_scaled_att(Dw,Yw,psw,W.n,1);

% Wave-specific g using IPW
gw = Yw.*(Dw./psw - (1-Dw)./(1-psw));

% Sort households by X and carry every per-household vector along
[Xs,order] = sortrows(Xw);
W.D = Dw(order);
W.Y = Yw(order);
W.ps = psw(order);

% bin1/bin2: index of the covariate / prev_use bin each household falls in
[~,~,~,W.bin1,W.bin2]=histcounts2(Xs(:,1),Xs(:,2),...
    'XBinEdges',boo1,'YBinEdges',boo2);

[W.Xu,W.gu,W.nw,W.x1,W.x2] = empirical_density_generate_grid(nu,boo1,boo2,W.bin1,W.bin2,gw(order));

end

function att = local_scaled_att(D,Y,ps,n,w)
% Scaled ATT with optional per-household weights w (use w=1 for unweighted).
% Kept in the original algebraic form, which reduces to sum(...)/n.
att = (sum(D.*Y.*w-((1-D).*Y.*w.*ps)./(1-ps))/sum(D))*sum(D)/n;
end

function dr_hh = local_density_ratio_per_household(dr,bin1,bin2,n_edges2)
% Map the per-grid-cell density ratio dr to each household.
% Cell (i,j) is stored at linear position (i-1)*n_edges2 + j, i.e. the
% covariate index is the outer loop and the prev_use index the inner loop.
% Households that fall in no bin (index 0) keep NaN.
dr_hh = nan(numel(bin1),1);
in_grid = (bin1>0) & (bin2>0);
cell_index = (bin1(in_grid)-1).*n_edges2 + bin2(in_grid);
dr_hh(in_grid) = dr(cell_index);
end